import os
import json
import math
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from models.ray_utils import (
    get_ortho_ray_directions_origins,
    get_ortho_rays,
    get_ray_directions,
)
from utils.misc import get_rank

from glob import glob
import PIL.Image


def camNormal2worldNormal(rot_c2w, camNormal):
    """Rotate per-pixel camera-space normals into world space.

    Args:
        rot_c2w: (3, 3) camera-to-world rotation matrix.
        camNormal: (H, W, 3) normal map in camera coordinates.

    Returns:
        (H, W, 3) normal map in world coordinates.
    """
    h, w, _ = camNormal.shape
    flat = camNormal.reshape(-1, 3)
    # Apply the rotation to every pixel's normal vector at once.
    rotated = np.einsum("ij,nj->ni", rot_c2w, flat)
    return rotated.reshape(h, w, 3)


def worldNormal2camNormal(rot_w2c, worldNormal):
    """Rotate per-pixel world-space normals into camera space.

    Args:
        rot_w2c: (3, 3) world-to-camera rotation matrix.
        worldNormal: (H, W, 3) normal map in world coordinates.

    Returns:
        (H, W, 3) normal map in camera coordinates.
    """
    h, w, _ = worldNormal.shape
    flat = worldNormal.reshape(-1, 3)
    # Rotate all pixel normals in a single batched product.
    rotated = np.einsum("ij,nj->ni", rot_w2c, flat)
    return rotated.reshape(h, w, 3)


def trans_normal(normal, RT_w2c, RT_w2c_target):
    """Re-express camera-space normals from one camera frame into another.

    Lifts `normal` to world space via the inverse of the source
    world-to-camera rotation, then projects into the target camera frame.

    Args:
        normal: (H, W, 3) normal map in the source camera's frame.
        RT_w2c: source world-to-camera extrinsic (at least 3x3 rotation).
        RT_w2c_target: target world-to-camera extrinsic.

    Returns:
        (H, W, 3) normal map in the target camera's frame.
    """
    src_rot_c2w = np.linalg.inv(RT_w2c[:3, :3])
    world_normals = camNormal2worldNormal(src_rot_c2w, normal)
    return worldNormal2camNormal(RT_w2c_target[:3, :3], world_normals)


def img2normal(img):
    """Map 8-bit pixel values in [0, 255] to normal components in [-1, 1]."""
    scaled = img / 255.0
    return scaled * 2 - 1


def normal2img(normal):
    """Map normal components in [-1, 1] back to 8-bit pixel values."""
    scaled = (normal * 0.5 + 0.5) * 255
    return np.uint8(scaled)


def norm_normalize(normal, dim=-1):
    """L2-normalize vectors along `dim`; epsilon guards against zero-length vectors."""
    length = np.linalg.norm(normal, axis=dim, keepdims=True)
    return normal / (length + 1e-6)


def RT_opengl2opencv(RT):
    """Convert a [R|t] world-to-camera matrix from OpenGL/Blender to OpenCV convention.

    OpenCV's camera looks down +Z with +Y down, so the Y and Z rows are
    negated: R_world2cv = R_bcam2cv @ R_world2bcam (same for t).

    Args:
        RT: (3, 4) or (4, 4) extrinsic matrix; only RT[:3, :4] is used.

    Returns:
        (3, 4) extrinsic matrix in the OpenCV convention.
    """
    flip_yz = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)

    R_cv = flip_yz @ RT[:3, :3]
    t_cv = flip_yz @ RT[:3, 3]

    return np.concatenate([R_cv, t_cv[:, None]], axis=1)


def normal_opengl2opencv(normal):
    """Convert an (H, W, 3) normal map from OpenGL to OpenCV camera convention.

    Flips the Y and Z components per pixel (the X axis is shared between the
    two conventions).

    Args:
        normal: (H, W, 3) array of normal vectors.

    Returns:
        (H, W, 3) array with Y and Z negated.
    """
    # BUGFIX: removed leftover debug print and the unused H/W/C unpack.
    flip_yz = np.array([1, -1, -1], np.float32)
    return normal * flip_yz[None, None, :]


def inv_RT(RT):
    """Invert a 3x4 [R|t] extrinsic matrix.

    Promotes the matrix to 4x4 homogeneous form, inverts it, and returns the
    top 3x4 rows (e.g. world2cam -> cam2world).
    """
    homogeneous = np.vstack([RT, [[0, 0, 0, 1]]])
    return np.linalg.inv(homogeneous)[:3, :]


def load_a_prediction(
    root_dir,
    test_object,
    imSize,
    view_types,
    load_color=False,
    cam_pose_dir=None,
    normal_system="front",
    erode_mask=True,
    camera_type="ortho",
    cam_params=None,
):
    """Load per-view normal maps, masks, poses, and rays for one object.

    Args:
        root_dir: directory containing per-object prediction folders.
        test_object: object sub-folder name under `root_dir`.
        imSize: (W, H) target size every image is resized to.
        view_types: iterable of view names, e.g. ["front", "right", ...].
        load_color: if True, also load the RGB images; otherwise the returned
            images are rendered from the world-space normals.
        cam_pose_dir: directory holding per-view "000_<view>_RT.txt"
            world-to-camera matrices (OpenGL convention).
        normal_system: "front" if all normal maps are expressed in the front
            view's camera frame, "self" if each map is in its own frame.
        erode_mask: currently unused — the erosion step is disabled below.
        camera_type: "ortho" or "pinhole"; selects the ray generator.
        cam_params: [fx, fy, cx, cy] intrinsics, required for "pinhole".

    Returns:
        Tuple of stacked numpy arrays, each with one leading entry per view:
        (images, masks, camera-space normals, world-space normals,
         cam2world poses, world2cam matrices, ray origins, ray directions,
         color masks).
    """
    all_images = []
    all_normals = []
    all_normals_world = []
    all_masks = []
    all_color_masks = []
    all_poses = []
    all_w2cs = []
    directions = []
    ray_origins = []

    # world2cam matrix of the front view; used when normal_system == "front".
    RT_front = np.loadtxt(
        glob(os.path.join(cam_pose_dir, "*_%s_RT.txt" % ("front")))[0]
    )
    RT_front_cv = RT_opengl2opencv(RT_front)  # OpenGL -> OpenCV convention

    for view in view_types:
        normal_filepath = os.path.join(
            root_dir, test_object, "normals_000_%s.png" % (view)
        )

        image = None
        if load_color:  # RGB image shares the normal map's naming scheme
            image = np.array(
                PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize)
            )[:, :, :3]

        # RGBA normal map: alpha channel is the foreground mask.
        normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
        mask = normal[:, :, 3]
        normal = normal[:, :, :3]

        color_mask = np.array(
            PIL.Image.open(
                os.path.join(
                    root_dir, test_object, "masked_colors/rgb_000_%s.png" % (view)
                )
            ).resize(imSize)
        )[:, :, 3]
        invalid_color_mask = color_mask < 255 * 0.5
        if image is not None:
            # A pixel is invalid only if it is outside the color mask AND
            # near-white (likely background bleed-through).
            threshold = np.ones_like(image[:, :, 0]) * 250
            invalid_white_mask = (
                (image[:, :, 0] > threshold)
                & (image[:, :, 1] > threshold)
                & (image[:, :, 2] > threshold)
            )
            invalid_color_mask_final = invalid_color_mask & invalid_white_mask
        else:
            # BUGFIX: the original referenced `image` here even when
            # load_color was False, raising NameError on the first view.
            # Without RGB data, fall back to the alpha-only invalid mask.
            invalid_color_mask_final = invalid_color_mask
        color_mask = (1 - invalid_color_mask_final) > 0

        # NOTE: mask erosion is intentionally disabled; `erode_mask` is
        # accepted for interface compatibility but has no effect.
        # if erode_mask:
        #     kernel = np.ones((3, 3), np.uint8)
        #     mask = cv2.erode(mask, kernel, iterations=1)

        RT = np.loadtxt(
            os.path.join(cam_pose_dir, "000_%s_RT.txt" % (view))
        )  # world2cam matrix

        normal = img2normal(normal)
        normal[mask == 0] = [0, 0, 0]
        mask = mask > (0.5 * 255)

        if load_color:
            all_images.append(image)

        all_masks.append(mask)
        all_color_masks.append(color_mask)

        RT_cv = RT_opengl2opencv(RT)  # convert pose from OpenGL to OpenCV
        all_poses.append(inv_RT(RT_cv))  # cam2world
        all_w2cs.append(RT_cv)

        normal_cam_cv = normal_opengl2opencv(normal)

        if normal_system == "front":
            # Normals are defined in the front view's camera frame.
            normal_world = camNormal2worldNormal(
                inv_RT(RT_front_cv)[:3, :3], normal_cam_cv
            )
        elif normal_system == "self":
            # Each view's normals are in its own camera frame.
            normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
        all_normals.append(normal_cam_cv)
        all_normals_world.append(normal_world)

        if camera_type == "ortho":
            origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
        elif camera_type == "pinhole":
            dirs = get_ray_directions(
                W=imSize[0],
                H=imSize[1],
                fx=cam_params[0],
                fy=cam_params[1],
                cx=cam_params[2],
                cy=cam_params[3],
            )
            origins = dirs  # placeholder to keep array shapes aligned
        else:
            raise Exception("not support camera type")
        ray_origins.append(origins)
        directions.append(dirs)

    if not load_color:
        # BUGFIX: hoisted out of the loop — the original rebuilt this list on
        # every iteration; the final result is identical.
        all_images = [normal2img(x) for x in all_normals_world]

    return (
        np.stack(all_images),
        np.stack(all_masks),
        np.stack(all_normals),
        np.stack(all_normals_world),
        np.stack(all_poses),
        np.stack(all_w2cs),
        np.stack(ray_origins),
        np.stack(directions),
        np.stack(all_color_masks),
    )


class OrthoDatasetBase:
    """Shared setup mixin: loads all views of one object onto the local rank's device."""

    def setup(self, config, split):
        """Load images, masks, normals, poses, and rays for every fixed view.

        Args:
            config: namespace-like config providing root_dir, scene, imSize,
                camera_type, camera_params, view_weights, apply_mask, and
                cam_pose_dir.
            split: dataset split name (stored; not used to select data here).
        """
        # BUGFIX: removed leftover debug print ("here\n\n").
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.data_dir = self.config.root_dir
        self.object_name = self.config.scene
        self.scene = self.config.scene
        self.imSize = self.config.imSize
        self.load_color = True
        self.img_wh = [self.imSize[0], self.imSize[1]]
        self.w = self.img_wh[0]
        self.h = self.img_wh[1]
        self.camera_type = self.config.camera_type
        self.camera_params = self.config.camera_params  # [fx, fy, cx, cy]

        # Fixed six-view ring around the object.
        self.view_types = [
            "front",
            "front_right",
            "right",
            "back",
            "left",
            "front_left",
        ]

        # Per-view loss weights, broadcast to one (H, W) map per view.
        self.view_weights = (
            torch.from_numpy(np.array(self.config.view_weights))
            .float()
            .to(self.rank)
            .view(-1)
        )
        self.view_weights = self.view_weights.view(-1, 1, 1).repeat(1, self.h, self.w)

        if self.config.cam_pose_dir is None:
            self.cam_pose_dir = "./datasets/fixed_poses"
        else:
            self.cam_pose_dir = self.config.cam_pose_dir

        (
            self.images_np,
            self.masks_np,
            self.normals_cam_np,
            self.normals_world_np,
            self.pose_all_np,
            self.w2c_all_np,
            self.origins_np,
            self.directions_np,
            self.rgb_masks_np,
        ) = load_a_prediction(
            self.data_dir,
            self.object_name,
            self.imSize,
            self.view_types,
            self.load_color,
            self.cam_pose_dir,
            normal_system="front",
            camera_type=self.camera_type,
            cam_params=self.camera_params,
        )

        self.has_mask = True
        self.apply_mask = self.config.apply_mask

        # Convert everything to float tensors on the local device; images are
        # rescaled from [0, 255] to [0, 1].
        self.all_c2w = torch.from_numpy(self.pose_all_np)
        self.all_images = torch.from_numpy(self.images_np) / 255.0
        self.all_fg_masks = torch.from_numpy(self.masks_np)
        self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
        self.all_normals_world = torch.from_numpy(self.normals_world_np)
        self.origins = torch.from_numpy(self.origins_np)
        self.directions = torch.from_numpy(self.directions_np)

        self.directions = self.directions.float().to(self.rank)
        self.origins = self.origins.float().to(self.rank)
        self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
        self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = (
            self.all_c2w.float().to(self.rank),
            self.all_images.float().to(self.rank),
            self.all_fg_masks.float().to(self.rank),
            self.all_normals_world.float().to(self.rank),
        )


class OrthoDataset(Dataset, OrthoDatasetBase):
    """Map-style dataset over the loaded views; items carry only the view index."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded view image.
        return self.all_images.shape[0]

    def __getitem__(self, index):
        # Tensors live on the base class; the batch only needs the index.
        return {"index": index}


class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
    """Infinite iterable dataset for training; each yield is an empty dict."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Stream indefinitely; the training loop decides when to stop.
        while True:
            yield dict()


@datasets.register("ortho")
class OrthoDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = OrthoIterableDataset(self.config, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = OrthoDataset(
                self.config, self.config.get("val_split", "train")
            )
        if stage in [None, "test"]:
            self.test_dataset = OrthoDataset(
                self.config, self.config.get("test_split", "test")
            )
        if stage in [None, "predict"]:
            self.predict_dataset = OrthoDataset(self.config, "train")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
